Add FunctionWrappers extension for differentiating through FunctionWrapper#2980
Add FunctionWrappers extension for differentiating through FunctionWrapper#2980ChrisRackauckas-Claude wants to merge 1 commit intoEnzymeAD:mainfrom
Conversation
…apper FunctionWrappers.jl wraps Julia functions behind C function pointers via ccall, which Enzyme cannot differentiate through. This adds an EnzymeFunctionWrappersExt extension that defines EnzymeRules for FunctionWrapper, extracting the original wrapped function and delegating to autodiff_deferred. This enables packages like NonlinearSolve.jl to use FunctionWrappers for their norecompile infrastructure without needing manual unwrapping at every call site that might use Enzyme. Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/ext/EnzymeFunctionWrappersExt.jl b/ext/EnzymeFunctionWrappersExt.jl
index 04abc32e..e6f71fa7 100644
--- a/ext/EnzymeFunctionWrappersExt.jl
+++ b/ext/EnzymeFunctionWrappersExt.jl
@@ -11,10 +11,10 @@ using Enzyme
# Helper to reconstruct an annotation with a cached primal value
@inline _reconstruct_arg(arg::Const, cached, overwritten::Bool) = arg
@inline function _reconstruct_arg(arg::Duplicated, cached, overwritten::Bool)
- overwritten && cached !== nothing ? Duplicated(cached, arg.dval) : arg
+ return overwritten && cached !== nothing ? Duplicated(cached, arg.dval) : arg
end
@inline function _reconstruct_arg(arg::BatchDuplicated, cached, overwritten::Bool)
- overwritten && cached !== nothing ? BatchDuplicated(cached, arg.dval) : arg
+ return overwritten && cached !== nothing ? BatchDuplicated(cached, arg.dval) : arg
end
@inline _reconstruct_arg(arg::Active, cached, overwritten::Bool) = arg
@@ -30,11 +30,11 @@ end
# Single rule for both IIP (Nothing return) and OOP FunctionWrappers.
# Extracts the wrapped function and delegates to autodiff_deferred.
function EnzymeRules.forward(
- config::EnzymeRules.FwdConfig,
- func::Const{<:FunctionWrapper},
- RT::Type{<:Annotation},
- args::Annotation...,
-)
+ config::EnzymeRules.FwdConfig,
+ func::Const{<:FunctionWrapper},
+ RT::Type{<:Annotation},
+ args::Annotation...,
+ )
raw_f = unwrap_fw(func.val)
# For IIP functions (Const{Nothing} return), needs_shadow is false but we
@@ -52,13 +52,13 @@ function EnzymeRules.forward(
# OOP: shadow is needed. Always use Duplicated for autodiff_deferred
# (it rejects DuplicatedNoNeed).
RealRt = eltype(RT)
- if EnzymeRules.needs_primal(config)
+ return if EnzymeRules.needs_primal(config)
res = Enzyme.autodiff_deferred(ForwardWithPrimal, Const(raw_f), Duplicated, args...)
# autodiff ForwardWithPrimal returns (derivs, primal)
if EnzymeRules.width(config) == 1
return Duplicated(res[2]::RealRt, res[1]::RealRt)
else
- return BatchDuplicated(res[2]::RealRt, res[1]::NTuple{EnzymeRules.width(config),RealRt})
+ return BatchDuplicated(res[2]::RealRt, res[1]::NTuple{EnzymeRules.width(config), RealRt})
end
else
res = Enzyme.autodiff_deferred(Forward, Const(raw_f), Duplicated, args...)
@@ -66,7 +66,7 @@ function EnzymeRules.forward(
if EnzymeRules.width(config) == 1
return res[1]::RealRt
else
- return res[1]::NTuple{EnzymeRules.width(config),RealRt}
+ return res[1]::NTuple{EnzymeRules.width(config), RealRt}
end
end
end
@@ -77,11 +77,11 @@ end
# augmented_primal: execute the forward pass, cache data for reverse
function EnzymeRules.augmented_primal(
- config::EnzymeRules.RevConfig,
- func::Const{<:FunctionWrapper{Ret}},
- RT::Type{<:Annotation},
- args::Annotation...,
-) where {Ret}
+ config::EnzymeRules.RevConfig,
+ func::Const{<:FunctionWrapper{Ret}},
+ RT::Type{<:Annotation},
+ args::Annotation...,
+ ) where {Ret}
raw_f = unwrap_fw(func.val)
ow = EnzymeRules.overwritten(config)
nargs = length(args)
@@ -129,12 +129,12 @@ end
# reverse for IIP (Nothing return): accumulate gradients into dval arrays
function EnzymeRules.reverse(
- config::EnzymeRules.RevConfig,
- func::Const{<:FunctionWrapper{Nothing}},
- ::Type{<:Const{Nothing}},
- tape,
- args::Annotation...,
-)
+ config::EnzymeRules.RevConfig,
+ func::Const{<:FunctionWrapper{Nothing}},
+ ::Type{<:Const{Nothing}},
+ tape,
+ args::Annotation...,
+ )
raw_f, cached_args = tape
ow = EnzymeRules.overwritten(config)
nargs = length(args)
@@ -154,12 +154,12 @@ end
# reverse for OOP with Active return: return scaled per-arg gradients
function EnzymeRules.reverse(
- config::EnzymeRules.RevConfig,
- func::Const{<:FunctionWrapper{Ret}},
- dret::Active,
- tape,
- args::Annotation...,
-) where {Ret}
+ config::EnzymeRules.RevConfig,
+ func::Const{<:FunctionWrapper{Ret}},
+ dret::Active,
+ tape,
+ args::Annotation...,
+ ) where {Ret}
raw_f, cached_args = tape
ow = EnzymeRules.overwritten(config)
nargs = length(args)
@@ -181,12 +181,12 @@ end
# reverse for OOP with Duplicated/Const return type (non-Active)
function EnzymeRules.reverse(
- config::EnzymeRules.RevConfig,
- func::Const{<:FunctionWrapper{Ret}},
- dret::Type{<:Annotation},
- tape,
- args::Annotation...,
-) where {Ret}
+ config::EnzymeRules.RevConfig,
+ func::Const{<:FunctionWrapper{Ret}},
+ dret::Type{<:Annotation},
+ tape,
+ args::Annotation...,
+ ) where {Ret}
if !(dret <: Const)
raw_f, cached_args = tape
ow = EnzymeRules.overwritten(config)
diff --git a/test/ext/functionwrappers.jl b/test/ext/functionwrappers.jl
index 94b4e9c8..79af62bf 100644
--- a/test/ext/functionwrappers.jl
+++ b/test/ext/functionwrappers.jl
@@ -10,19 +10,23 @@ using FunctionWrappers: FunctionWrapper
f_oop(x, p) = p[1] * x^2
@testset "IIP Forward Mode" begin
- fw = FunctionWrapper{Nothing,Tuple{Vector{Float64},Vector{Float64},Vector{Float64}}}(f!)
+ fw = FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}}}(f!)
u = [2.0]; du = zeros(1); p = [3.0]
ddu = zeros(1); du_u = [1.0]
# Differentiate through FunctionWrapper
- Enzyme.autodiff(Forward, fw, Const{Nothing},
- Duplicated(du, ddu), Duplicated(u, du_u), Const(p))
+ Enzyme.autodiff(
+ Forward, fw, Const{Nothing},
+ Duplicated(du, ddu), Duplicated(u, du_u), Const(p)
+ )
# Compare with raw function
u2 = [2.0]; du2 = zeros(1); ddu2 = zeros(1); du_u2 = [1.0]
- Enzyme.autodiff(Forward, f!, Const{Nothing},
- Duplicated(du2, ddu2), Duplicated(u2, du_u2), Const(p))
+ Enzyme.autodiff(
+ Forward, f!, Const{Nothing},
+ Duplicated(du2, ddu2), Duplicated(u2, du_u2), Const(p)
+ )
@test ddu ≈ ddu2
# ddu[1] should be d/du(p*u^2) * du_u = 3.0 * 2 * 2.0 * 1.0 = 12.0
@@ -30,18 +34,22 @@ using FunctionWrappers: FunctionWrapper
end
@testset "IIP Reverse Mode" begin
- fw = FunctionWrapper{Nothing,Tuple{Vector{Float64},Vector{Float64},Vector{Float64}}}(f!)
+ fw = FunctionWrapper{Nothing, Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}}}(f!)
u = [2.0]; du = zeros(1); p = [3.0]
ddu = [1.0]; du_u = zeros(1)
- Enzyme.autodiff(Reverse, fw, Const{Nothing},
- Duplicated(du, ddu), Duplicated(u, du_u), Const(p))
+ Enzyme.autodiff(
+ Reverse, fw, Const{Nothing},
+ Duplicated(du, ddu), Duplicated(u, du_u), Const(p)
+ )
# Compare with raw function
u2 = [2.0]; du2 = zeros(1); ddu2 = [1.0]; du_u2 = zeros(1)
- Enzyme.autodiff(Reverse, f!, Const{Nothing},
- Duplicated(du2, ddu2), Duplicated(u2, du_u2), Const(p))
+ Enzyme.autodiff(
+ Reverse, f!, Const{Nothing},
+ Duplicated(du2, ddu2), Duplicated(u2, du_u2), Const(p)
+ )
@test du_u ≈ du_u2
# du/du[1] of (du[1] = p[1]*u[1]^2) with seed ddu[1]=1.0:
@@ -50,17 +58,21 @@ using FunctionWrappers: FunctionWrapper
end
@testset "OOP Forward Mode" begin
- fw_oop = FunctionWrapper{Float64,Tuple{Float64,Vector{Float64}}}(f_oop)
+ fw_oop = FunctionWrapper{Float64, Tuple{Float64, Vector{Float64}}}(f_oop)
x = 3.0; p = [2.0]
dx = 1.0
- res = Enzyme.autodiff(Forward, fw_oop, Duplicated,
- Duplicated(x, dx), Const(p))
+ res = Enzyme.autodiff(
+ Forward, fw_oop, Duplicated,
+ Duplicated(x, dx), Const(p)
+ )
# Compare with raw function
- res2 = Enzyme.autodiff(Forward, f_oop, Duplicated,
- Duplicated(x, dx), Const(p))
+ res2 = Enzyme.autodiff(
+ Forward, f_oop, Duplicated,
+ Duplicated(x, dx), Const(p)
+ )
@test res[1] ≈ res2[1]
# d/dx(p*x^2) = 2*p*x = 2*2.0*3.0 = 12.0
@@ -68,16 +80,20 @@ using FunctionWrappers: FunctionWrapper
end
@testset "OOP Reverse Mode" begin
- fw_oop = FunctionWrapper{Float64,Tuple{Float64,Vector{Float64}}}(f_oop)
+ fw_oop = FunctionWrapper{Float64, Tuple{Float64, Vector{Float64}}}(f_oop)
x = 3.0; p = [2.0]
- res = Enzyme.autodiff(Reverse, fw_oop, Active,
- Active(x), Const(p))
+ res = Enzyme.autodiff(
+ Reverse, fw_oop, Active,
+ Active(x), Const(p)
+ )
# Compare with raw function
- res2 = Enzyme.autodiff(Reverse, f_oop, Active,
- Active(x), Const(p))
+ res2 = Enzyme.autodiff(
+ Reverse, f_oop, Active,
+ Active(x), Const(p)
+ )
@test res[1][1] ≈ res2[1][1]
# d/dx(p*x^2) = 2*p*x = 2*2.0*3.0 = 12.0 |
Session State Summary (for continuation)What this PR doesAdds Problem: FunctionWrappers.jl wraps Julia functions behind C function pointers via Solution: The extension extracts the original function via Files created/modified
Test results (local, Julia 1.10)All 8 tests pass:
Design decisions
Downstream dependencySciML/NonlinearSolve.jl PR #838 depends on this PR. NonlinearSolve currently has ~73 lines of manual Enzyme workaround code ( CI statusCI workflows show Local repo
Co-Authored-By: Chris Rackauckas accounts@chrisrackauckas.com |
Summary
EnzymeFunctionWrappersExtextension that defines forward and reverse mode EnzymeRules forFunctionWrapper{Ret,Args}fw.obj[]and delegates toautodiff_deferred, bypassing theccallbarrier that Enzyme cannot differentiate throughAutoSpecializeCallable) without manual unwrapping at every Enzyme call siteMotivation
FunctionWrappers.jl wraps Julia functions behind C function pointers using
ccall/llvmcall. Enzyme cannot differentiate through this mechanism, throwingEnzymeMutabilityException. NonlinearSolve.jl (PR #838) uses FunctionWrappers for its norecompile infrastructure and currently works around this by manually unwrapping at every call site — a fragile, whack-a-mole approach across 4+ files.This extension solves the problem at the source: Enzyme automatically differentiates through the original wrapped function.
Implementation
Forward mode rule — single method handles both IIP (
Nothingreturn) and OOP:RT <: Const): Always runsautodiff_deferred(Forward, ...)to propagate tangents into argument shadow arraysDuplicatedinternally (sinceautodiff_deferredrejectsDuplicatedNoNeed), with type assertions for return stabilityReverse mode rules —
augmented_primal+ 3reversemethods:augmented_primal: Executes primal via unwrapped function, caches copies of overwritten argsautodiff_deferred(Reverse, ...)dret.valusing type-stable helperautodiff_deferredfor gradient accumulationDesign decisions:
func::Const{<:FunctionWrapper}(wrapper itself not differentiated) — covers NonlinearSolve and all standard usesBatchDuplicatedannotations passed through toautodiff_deferredFiles changed
Project.toml— FunctionWrappers weakdep, extension, compat (1.1+), extrasext/EnzymeFunctionWrappersExt.jl— New extension module (209 lines)test/ext/functionwrappers.jl— 8 tests: IIP/OOP × Forward/Reverse, all verified against raw functiontest/Project.toml— FunctionWrappers test dependencyTest plan
🤖 Generated with Claude Code